import matplotlib

matplotlib.use("agg")
import os
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
from mpl_toolkits.mplot3d import Axes3D


def export_to_pc_batch(dir, pcs, colors=None):
    Path(dir).mkdir(parents=True, exist_ok=True)
    for i, xyz in enumerate(pcs):
        if colors is None:
            color = None
        else:
            color = colors[i]
        pcwrite(os.path.join(dir, "sample_" + str(i) + ".ply"), xyz, color)


def meshwrite(filename, verts, faces, norms, colors):
    """Save a 3D mesh to a polygon .ply file."""
    # Write header
    ply_file = open(filename, "w")
    ply_file.write("ply\n")
    ply_file.write("format ascii 1.0\n")
    ply_file.write("element vertex %d\n" % (verts.shape[0]))
    ply_file.write("property float x\n")
    ply_file.write("property float y\n")
    ply_file.write("property float z\n")
    ply_file.write("property float nx\n")
    ply_file.write("property float ny\n")
    ply_file.write("property float nz\n")
    ply_file.write("property uchar red\n")
    ply_file.write("property uchar green\n")
    ply_file.write("property uchar blue\n")
    ply_file.write("element face %d\n" % (faces.shape[0]))
    ply_file.write("property list uchar int vertex_index\n")
    ply_file.write("end_header\n")

    # Write vertex list
    for i in range(verts.shape[0]):
        ply_file.write(
            "%f %f %f %f %f %f %d %d %d\n"
            % (
                verts[i, 0],
                verts[i, 1],
                verts[i, 2],
                norms[i, 0],
                norms[i, 1],
                norms[i, 2],
                colors[i, 0],
                colors[i, 1],
                colors[i, 2],
            )
        )

    # Write face list
    for i in range(faces.shape[0]):
        ply_file.write("3 %d %d %d\n" % (faces[i, 0], faces[i, 1], faces[i, 2]))

    ply_file.close()


def pcwrite(filename, xyz, rgb=None):
    """Save a point cloud to a polygon .ply file."""
    if rgb is None:
        rgb = np.ones_like(xyz) * 128
    rgb = rgb.astype(np.uint8)

    # Write header
    ply_file = open(filename, "w")
    ply_file.write("ply\n")
    ply_file.write("format ascii 1.0\n")
    ply_file.write("element vertex %d\n" % (xyz.shape[0]))
    ply_file.write("property float x\n")
    ply_file.write("property float y\n")
    ply_file.write("property float z\n")
    ply_file.write("property uchar red\n")
    ply_file.write("property uchar green\n")
    ply_file.write("property uchar blue\n")
    ply_file.write("end_header\n")

    # Write vertex list
    for i in range(xyz.shape[0]):
        ply_file.write(
            "%f %f %f %d %d %d\n"
            % (
                xyz[i, 0],
                xyz[i, 1],
                xyz[i, 2],
                rgb[i, 0],
                rgb[i, 1],
                rgb[i, 2],
            )
        )


"""
Matplotlib Visualization 
"""


def visualize_voxels(out_file, voxels, num_shown=16, threshold=0.5):
    r"""Visualizes voxel data.
    show only first num_shown
    """
    batch_size = voxels.shape[0]
    voxels = voxels.squeeze(1) > threshold

    num_shown = min(num_shown, batch_size)

    n = int(np.sqrt(num_shown))
    fig = plt.figure(figsize=(20, 20))

    for idx, pc in enumerate(voxels[:num_shown]):
        if idx >= n * n:
            break
        pc = voxels[idx]
        ax = fig.add_subplot(n, n, idx + 1, projection="3d")
        ax.voxels(pc, edgecolor="k", facecolors="green", linewidth=0.1, alpha=0.5)
        ax.view_init()
        ax.axis("off")
    plt.savefig(out_file, bbox_inches="tight")
    plt.close()


def visualize_pointcloud(points, normals=None, out_file=None, show=False, elev=30, azim=225):
    r"""Visualizes point cloud data.
    Args:
        points (tensor): point data
        normals (tensor): normal data (if existing)
        out_file (string): output file
        show (bool): whether the plot should be shown
    """
    # Create plot
    fig = plt.figure()
    ax = fig.gca(projection=Axes3D.name)
    ax.scatter(points[:, 2], points[:, 0], points[:, 1])
    if normals is not None:
        ax.quiver(
            points[:, 2],
            points[:, 0],
            points[:, 1],
            normals[:, 2],
            normals[:, 0],
            normals[:, 1],
            length=0.1,
            color="k",
        )
    ax.set_xlabel("Z")
    ax.set_ylabel("X")
    ax.set_zlabel("Y")
    # ax.set_xlim(-0.5, 0.5)
    # ax.set_ylim(-0.5, 0.5)
    # ax.set_zlim(-0.5, 0.5)
    ax.view_init(elev=elev, azim=azim)
    if out_file is not None:
        plt.savefig(out_file)
    if show:
        plt.show()
    plt.close(fig)


def visualize_pointcloud_batch(
    path,
    pointclouds,
    pred_labels=None,
    labels=None,
    categories=None,
    vis_label=False,
    target=None,
    elev=30,
    azim=225,
):
    batch_size = len(pointclouds)
    fig = plt.figure(figsize=(20, 20))

    ncols = int(np.sqrt(batch_size))
    nrows = max(1, (batch_size - 1) // ncols + 1)
    for idx, pc in enumerate(pointclouds):
        if vis_label:
            label = categories[labels[idx].item()]
            pred = categories[pred_labels[idx]]
            colour = "g" if label == pred else "r"
        elif target is None:
            colour = "g"
        else:
            colour = target[idx]
        pc = pc.cpu().numpy()
        # subsample points if more than 10k
        if pc.shape[0] > 10000:
            pc = pc[np.random.choice(pc.shape[0], 10000, replace=False), :]
        # plot and save
        ax = fig.add_subplot(nrows, ncols, idx + 1, projection="3d")
        ax.scatter(pc[:, 0], pc[:, 2], pc[:, 1], c=colour, s=5)
        ax.view_init(elev=elev, azim=azim)
        ax.axis("off")
        if vis_label:
            ax.set_title("GT: {0}\nPred: {1}".format(label, pred))

    # make dir
    os.makedirs(os.path.dirname(path), exist_ok=True)
    plt.savefig(path)
    plt.close(fig)


"""
Plot stats
"""


def plot_stats(output_dir, stats, interval):
    content = stats.keys()
    # f = plt.figure(figsize=(20, len(content) * 5))
    f, axs = plt.subplots(len(content), 1, figsize=(20, len(content) * 5))
    for j, (k, v) in enumerate(stats.items()):
        axs[j].plot(interval, v)
        axs[j].set_ylabel(k)

    f.savefig(os.path.join(output_dir, "stat.pdf"), bbox_inches="tight")
    plt.close(f)
